import astropy.io.fits as fits
import math
import numpy as np
from astropy.wcs import WCS
from astropy.coordinates import SkyCoord
import astropy.units as u
import pickle
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.ticker import (MultipleLocator, AutoMinorLocator)
# Use a 20 pt default font for every figure produced by this script.
plt.rcParams['font.size'] = 20

def rms(x, axis=None):
    """Return the root-mean-square of ``x``, optionally reduced along ``axis``."""
    mean_square = np.mean(x * x, axis=axis)
    return np.sqrt(mean_square)

def get_spectrum(imagename,coord,radius,exclude_inner=False):
    """Extract a circular-aperture spectrum and its uncertainty from a FITS cube.

    Parameters
    ----------
    imagename : str
        Path to the FITS cube. The primary HDU data is indexed
        ``[channel, y, x]`` and header axis 3 is the spectral axis.
    coord : astropy.coordinates.SkyCoord
        Aperture centre; only the first element is used.
    radius : float
        Aperture radius in arcsec (converted with ``CDELT2``, assumed to be
        in degrees and positive).
    exclude_inner : bool, optional
        If True, pixels within 0.1 arcsec of the centre are removed from the
        aperture (an annulus) and the beam count is reduced accordingly.

    Returns
    -------
    freq_axis : ndarray
        Channel world values from CRVAL3/CRPIX3/CDELT3.
    spectrum, e_spectrum : ndarray
        Aperture sum divided by the beam area in pixels, and its uncertainty
        (Jy if the cube is in Jy/beam -- assumed, not checked).
    spectrum_beam, e_spectrum_beam : ndarray
        Raw aperture sum in cube units, and its uncertainty.
    nbeam_per_spectrum : float
        Number of beams covered by the aperture.
    npix_beam : float
        Gaussian beam area (BMAJ, BMIN FWHM) in pixels.

    The uncertainty per channel is the rms of all pixels outside the aperture
    times sqrt(nbeam_per_spectrum).
    """
    with fits.open(imagename) as hdul:
        header = hdul[0].header
        cube = hdul[0].data

        # Only the spatial pixel position is used; the spectral world value
        # just fills the WCS's third axis.
        # NOTE(review): 225.0*u.Hz looks like it was meant to be 225 GHz; it
        # should not change x/y for a separable celestial+spectral WCS.
        pix = WCS(header).world_to_pixel(coord, 225.0 * u.Hz)
        x = pix[0][0]
        y = pix[1][0]  # was pix[0][0], which centred the aperture at (x, x)

        arcsec_per_px = header['CDELT2'] * 3600.0
        radius_px = radius / arcsec_per_px
        inner_radius_px = 0.1 / arcsec_per_px
        nx = header['NAXIS1']
        ny = header['NAXIS2']
        nchans = header['NAXIS3']

        # create_circular_mask2 returns (ny, nx) arrays indexed [y, x], exactly
        # like each channel plane, so no transpose is needed.
        mask, inverse_mask = create_circular_mask2(nx, ny, center=[x, y], radius=radius_px)
        aperture_area_px = math.pi * radius_px ** 2
        if exclude_inner:
            inner_mask, _ = create_circular_mask2(nx, ny, center=[x, y], radius=inner_radius_px)
            # Multiplying (rather than subtracting) never yields -1 pixels,
            # even if the inner radius exceeds the aperture radius.
            mask = mask * (1.0 - inner_mask)
            aperture_area_px -= math.pi * min(inner_radius_px, radius_px) ** 2

        npix_beam = (math.pi * header['BMAJ'] * header['BMIN']
                     / (4.0 * math.log(2.0)) / header['CDELT2'] ** 2)
        nbeam_per_spectrum = aperture_area_px / npix_beam
        noise_scale = nbeam_per_spectrum ** 0.5

        # FITS convention: world = CRVAL + (pixel - CRPIX) * CDELT with 1-based
        # pixels. CRPIX3 defaults to 1 so headers lacking it behave as before.
        channel = np.arange(nchans)
        freq_axis = header['CRVAL3'] + (channel + 1 - header.get('CRPIX3', 1.0)) * header['CDELT3']

        # Noise pixels (everything outside the aperture) are fixed per cube.
        noise_rows, noise_cols = np.nonzero(inverse_mask == 1)
        # NOTE(review): np.sum / rms propagate NaN; if the cube contains
        # blanked (NaN) pixels consider np.nansum / a NaN-aware rms.
        spectrum_beam = np.zeros(nchans)
        e_spectrum_beam = np.zeros(nchans)
        for i in range(nchans):
            plane = cube[i, :, :]
            spectrum_beam[i] = np.sum(plane * mask)
            e_spectrum_beam[i] = rms(plane[noise_rows, noise_cols]) * noise_scale

    spectrum = spectrum_beam / npix_beam
    e_spectrum = e_spectrum_beam / npix_beam
    return freq_axis, spectrum, e_spectrum, spectrum_beam, e_spectrum_beam, nbeam_per_spectrum, npix_beam


def create_circular_mask(h, w, center=None, radius=None):
    """Return a boolean (h, w) mask that is True inside a disc.

    ``center`` is ``(x, y)`` in pixels (column, row) and defaults to the
    integer middle of the image. ``radius`` defaults to the largest disc
    around ``center`` that still touches no image edge. Pixels whose distance
    from the centre is <= radius are True.
    """
    if center is None:
        center = (int(w / 2), int(h / 2))
    cx = center[0]
    cy = center[1]

    if radius is None:
        radius = min(cx, cy, w - cx, h - cy)

    rows, cols = np.ogrid[:h, :w]
    distance = np.sqrt((cols - cx) ** 2 + (rows - cy) ** 2)
    return distance <= radius

def create_circular_mask2(nx, ny, center=None, radius=None):
    """Return complementary float masks for a disc of pixels.

    Parameters
    ----------
    nx, ny : int
        Image width (columns) and height (rows) in pixels.
    center : sequence of float, optional
        ``(x, y)`` disc centre in pixels; defaults to the integer middle of
        the image (same convention as ``create_circular_mask``).
    radius : float, optional
        Disc radius in pixels; defaults to the largest disc around ``center``
        that touches no image edge.

    Returns
    -------
    mask, inverse_mask : ndarray
        Float arrays of shape ``(ny, nx)`` indexed ``[y, x]``. ``mask`` is 1.0
        where the pixel distance from the centre is strictly less than
        ``radius`` and 0.0 elsewhere; ``inverse_mask`` is its complement.
    """
    if center is None:
        center = (int(nx / 2), int(ny / 2))
    cx, cy = center[0], center[1]
    if radius is None:
        radius = min(cx, cy, nx - cx, ny - cy)

    # Broadcast a row of x coordinates against a column of y coordinates to
    # get a (ny, nx) distance map; the previous element loop indexed this
    # array as [x, y], which broke for non-square images.
    xs = np.arange(nx, dtype=float)[np.newaxis, :]
    ys = np.arange(ny, dtype=float)[:, np.newaxis]
    dist = np.sqrt((xs - cx) ** 2 + (ys - cy) ** 2)

    inside = dist < radius
    mask = inside.astype(float)
    inverse_mask = (~inside).astype(float)
    return mask, inverse_mask




# Line cubes to extract spectra from. Each entry maps a short key to the FITS
# cube ('filename') and a plot label ('title'); the extraction and plotting
# loops add more fields to each entry. Commented-out entries are alternative
# cubes that can be re-enabled.
spectral_dict = {}
#spectral_dict['C18O']={'filename': 'V883_Ori_C18O-21_image.image.fits','title': 'C$^{18}$O (J=2-1)'}
spectral_dict['C17O'] = {'filename': 'V883_Ori_SB_C17O-224.711GHz_robust_2.0.image.fits', 'title': 'C$^{17}$O (J=2-1)'}
#spectral_dict['C17O_full']={'filename': 'V883_Ori_SB_C17O-224.711GHz-fullcube_robust_2.0.image.fits','title': 'C$^{17}$O (J=2-1) Full'}
spectral_dict['CH3OH_241'] = {'filename': 'V883_Ori_SB_CH3OH-241.846GHz-fullcube_robust_2.0.image.fits', 'title': 'CH$_3$OH 241 GHz'}
spectral_dict['HDO_241'] = {'filename': 'V883_Ori_SB_HDO-241.558GHz_robust_2.0.image.fits', 'title': 'HDO 241 GHz'}
spectral_dict['HDO_225'] = {'filename': 'V883_Ori_SB_HDO-225.535GHz_robust_2.0.image.fits', 'title': 'HDO 225 GHz'}
# Alternative H2 18O cube without the 0.4 km/s channelisation:
# 'V883_Ori_SB_H218O-203.GHz_robust_2.0.image.fits'
spectral_dict['H218O_203'] = {'filename': 'V883_Ori_SB_H218O-203.GHz-0.4kms_robust_2.0.image.fits', 'title': 'H$_2^{18}$O 203 GHz'}

#spectral_dict['HDO_241_full']={'filename': 'V883_Ori_SB_HDO-241.558GHz-fullcube_robust_2.0.image.fits','title': 'HDO 241 GHz Full'}
#spectral_dict['HDO_225_full']={'filename': 'V883_Ori_SB_HDO-225.535GHz-fullcube_robust_2.0.image.fits','title': 'HDO 225 GHz Full'}
#spectral_dict['H2CO']={'filename': 'V883_Ori_SB_H2CO-225.694GHz-fullcube_robust_2.0.image.fits','title': 'H$_2$CO 225 GHz'}
#spectral_dict['continuum']={'filename': 'V883_Ori_SB_wide-2GHz-240.496GHz-fullcube_robust_2.0.image.fits','title': 'Continuum 240 GHz'}
#spectral_dict['CH3OD']={'filename': 'V883_Ori_SB_CH3OD-226.535GHz-fullcube_robust_2.0.image.fits','title': 'CH$_3$OD'}
#spectral_dict['CN_225.65']={'filename': 'V883_Ori_SB_CN-225.656GHz-fullcube_robust_2.0.image.fits','title': 'CN 225.65 GHz'}
#spectral_dict['CN_225.87']={'filename': 'V883_Ori_SB_CN-225.871GHz-fullcube_robust_2.0.image.fits','title': 'CN 225.87 GHz'}

# ICRS position (sexagesimal RA/Dec) used as the aperture centre for every
# spectrum extracted below; wrapped in a list so coord is a length-1 SkyCoord.
radecstring = '05h38m18.100454s -07d02m25.99340s'
coord = SkyCoord(
    [radecstring],
    frame='icrs',
    unit=(u.hourangle, u.deg),
)



# Extract a spectrum for every configured cube and store the seven values
# returned by get_spectrum back into that cube's entry. Aperture radii are in
# arcsec: 1.0 for keys containing 'CN' or 'C18O', 0.4 for everything else.
for key, entry in spectral_dict.items():
    print(key)
    radius = 1.0 if ('CN' in key or 'C18O' in key) else 0.4
    (freq_vals, flux, flux_err, flux_beam, flux_beam_err,
     n_beams, pix_per_beam) = get_spectrum(entry['filename'], coord, radius,
                                           exclude_inner=False)
    entry['freq_axis'] = freq_vals
    entry['spectrum'] = flux
    entry['e_spectrum'] = flux_err
    entry['spectrum_Jy_beam'] = flux_beam
    entry['e_spectrum_Jy_beam'] = flux_beam_err
    entry['nbeams'] = n_beams
    entry['npix_per_beam'] = pix_per_beam



# Plot each extracted spectrum against frequency in GHz (freq_axis / 1e9) and
# save it as '<key>_spectrum.png'.
for key, entry in spectral_dict.items():
    freq_ghz = entry['freq_axis'] / 1e9
    fig, ax = plt.subplots(1, 1, figsize=(15, 6))
    ax.plot(freq_ghz, entry['spectrum'], color='black', drawstyle='steps', linewidth=1)
    ax.xaxis.set_major_locator(MultipleLocator(0.03))
    ax.xaxis.set_minor_locator(AutoMinorLocator())
    ax.tick_params(which='both', width=2)
    ax.tick_params(which='major', length=4)
    ax.tick_params(which='minor', length=2)
    ax.set_title(key)
    ax.set_ylabel('Flux Density (Jy)')
    ax.set_xlabel('Frequency (GHz)', labelpad=-4)
    ax.set_xlim(np.min(freq_ghz), np.max(freq_ghz))
    fig.savefig(key + '_spectrum.png')
    # Release the figure; otherwise each iteration leaves one open.
    plt.close(fig)


# Cache the extracted spectra so later steps can reload them without
# re-reading the FITS cubes; the context manager guarantees the file is closed.
with open("spectral_dict.pickle", "wb") as pickle_file:
    pickle.dump(spectral_dict, pickle_file)




# Reload the cached spectra. This file is written by this script itself;
# never unpickle files from untrusted sources.
with open("spectral_dict.pickle", "rb") as pickle_file:
    spectral_dict = pickle.load(pickle_file)

# Velocity in km/s (3.0e5 below is the speed of light in km/s) used to shift
# each frequency axis by vlsr/c times its median frequency; 0.0 = no shift.
vlsr = 0.0
for key, entry in spectral_dict.items():
    entry['freq_axis'] = entry['freq_axis'] + vlsr / 3.0e5 * np.median(entry['freq_axis'])
    # One row per channel: frequency / 1e6 (MHz for an Hz axis), flux, error.
    rows = ['{} {} {} \n'.format(freq / 1.0e6, flux, flux_err)
            for freq, flux, flux_err in zip(entry['freq_axis'], entry['spectrum'], entry['e_spectrum'])]
    with open(key + '_text_spectra.txt', 'w') as outfile:
        outfile.writelines(rows)
      




